import os

import util.utils as utils
import time
import numpy as np
import torch
import matplotlib.cm as cm
import matplotlib.pyplot as plt

from torch.utils.tensorboard import SummaryWriter

class Logger(object):
    def __init__(self, args) -> None:
        """Prepare output directories, the TensorBoard writer and empty logs.

        Args:
            args: Parsed run configuration. This constructor reads
                ``args.cuda`` directly; ``_make_dir`` reads further fields.
        """
        self.args = args
        self.cuda = args.cuda
        self.device = "cuda" if self.cuda else "cpu"

        # Per-layer histories, filled by the channel logging methods.
        self.saturation_logs, self.skewness_logs = [], []

        # _make_dir() must run first: it sets tensorboard_save_dir.
        self._make_dir()
        self.writer = SummaryWriter(self.tensorboard_save_dir)

    def _get_state_id(self, file_type=""):
        """Return the run identifier, optionally suffixed with ``file_type``.

        Args:
            file_type: Suffix to append as ``<state_id>_<file_type>``. The
                empty default (or ``None``) means "no suffix".

        Returns:
            ``self.state_id`` when no suffix is requested, otherwise
            ``state_id`` and ``file_type`` joined by an underscore.
        """
        # The default is "" (not None), so an `is not None` test would send the
        # no-argument call down the join branch and yield "<state_id>_".
        if not file_type:
            return self.state_id
        return '_'.join([self.state_id, file_type])

    def _make_dir(self):
        if self.args.pretrained is None:
            self.state_id = utils.make_state_id(self.args)

        else:
            self.state_id = self.args.pretrained.split('/')[-1].replace('.pth.tar', '')
            exist_dir = os.path.join('log', self.state_id)
            os.makedirs(exist_dir, exist_ok=True)

        self.model_save_dir = os.path.join('saved_models')
        self.tensorboard_save_dir = os.path.join('log', self._get_state_id(), 'tensorboard', time.strftime('%Y-%m-%d_%I:%M:%S_%p', time.localtime(time.time())))
        self.acc_save_dir = os.path.join('log', 'acc')
        self.info_save_dir = os.path.join('log', self._get_state_id())

        os.makedirs(self.model_save_dir, exist_ok=True)
        os.makedirs(self.acc_save_dir, exist_ok=True)
        os.makedirs(self.info_save_dir, exist_ok=True)
        if not self.args.evaluate:
            os.makedirs(self.tensorboard_save_dir, exist_ok=True)

    def _save_log(self, save_file_name, content, content_type='acc'):
        with open(save_file_name, 'a') as f:
            if content_type == 'accuracy':
                f.write(str(content)+'\n')

            elif content_type == 'skewness' or content_type == 'saturation':
                for c in content:
                    f.write(format(c, '.8f')+'\n')
                f.write('#### \n')

    def calculate_skewness(self, model, loader):
        model.eval()

        def _calculate_accumulated_statistic(feature, samples, statistic_type, mean=None, var=None, skewness=None):
            for ind, block_output in enumerate(feature):

                # (mb, c, h, w) -> (c, mb, h, w)
                block_output_t = block_output.transpose(0, 1)
                channels = block_output_t.shape[0]

                # (c, mb, h, w) -> (c, mb*h*w)
                block_output_channel = block_output_t.contiguous().view(channels, -1)

                if statistic_type == 'mean':
                    channel_statistic = block_output_channel
                    # (c, mb*h*w) -> (c)
                    batch_statistic = torch.sum(channel_statistic, axis=1, keepdim=True)
    
                    if block_output.ndim == 4:
                        c, mb, h, w = block_output.shape
                        batch_statistic = batch_statistic / (samples * h * w)
                        # for skewness
                        n = samples * h * w
                    else:
                        batch_statistic = batch_statistic / (samples)
                        # for skewness
                        n = samples

                    mean[ind] += batch_statistic

                elif statistic_type == 'var':
                    diff = block_output_channel - mean[ind]
                    channel_statistic = torch.pow(diff, 2.0)

                    # (c, mb*h*w) -> (c)
                    batch_statistic = torch.sum(channel_statistic, axis=1, keepdim=True)
                
                    if block_output.ndim == 4:
                        c, mb, h, w = block_output.shape
                        batch_statistic = batch_statistic / (samples * h * w)
                        # for skewness
                        n = samples * h * w
                    else:
                        batch_statistic = batch_statistic / (samples)
                        # for skewness
                        n = samples

                    var[ind] += batch_statistic

                elif statistic_type == 'skewness':
                    std = torch.pow(torch.sqrt(var[ind]), 3.0)
                    diff = block_output_channel - mean[ind]
                    channel_statistic = torch.pow(diff, 3.0) / std
                    # (c, mb*h*w) -> (c)
                    batch_statistic = torch.sum(channel_statistic, axis=1, keepdim=True)
                
                    if block_output.ndim == 4:
                        c, mb, h, w = block_output.shape
                        batch_statistic = batch_statistic / (samples * h * w)
                        # for skewness
                        n = samples * h * w
                    else:
                        batch_statistic = batch_statistic / (samples)
                        # for skewness
                        n = samples
                    
                    batch_statistic *= np.sqrt(n*(n-1)) / (n-2)
                    skewness[ind] += batch_statistic

        with torch.no_grad():
            mean = []
            var = []
            skewness = []
            samples = len(loader.dataset)

            depth, channels = self.get_depth_channels(model=model)

            for d in range(depth):
                mean.append(torch.zeros((channels[d], 1), device=self.device))
                var.append(torch.zeros((channels[d], 1), device=self.device))
                skewness.append(torch.zeros((channels[d], 1), device=self.device))

            # calculate mean
            for data, target in loader:
                if self.cuda:
                    data, target = data.cuda(), target.cuda()

                output = model.module.get_activation(data, target='activation')

                _calculate_accumulated_statistic(output, samples, 'mean', mean)

            # calculate var
            for data, target in loader:
                if self.cuda:
                    data, target = data.cuda(), target.cuda()

                output = model.module.get_activation(data, target='activation')

                _calculate_accumulated_statistic(output, samples, 'var', mean, var)

            # calculate skewness
            for data, target in loader:
                if self.cuda:
                    data, target = data.cuda(), target.cuda()

                output = model.module.get_activation(data, target='activation')

                _calculate_accumulated_statistic(output, samples, 'skewness', mean, var, skewness)

            block_skewness = np.zeros(len(skewness))

            # average the skewness of channels
            for ind, block_skewness_unit in enumerate(skewness):
                block_skewness[ind] = torch.mean(torch.abs(block_skewness_unit))

            # print block skewness
            if not self.args.channel_logging:
                print("Skewness")
                for skwn in block_skewness:
                    print(format(skwn, '.8f'))
                print()

        return block_skewness

    def calculate_saturation(self, model, loader, empirical):
        model.eval()
        total_samples = len(loader.dataset)

        with torch.no_grad():
            block_sum = None
            block_len = None
            accumulated_minmax = []

            if empirical:
                acti_upper = None
            elif self.args.activation_type == 'lecun':
                acti_upper = 1.7159
            else:
                acti_upper = 1
            
            # Init block depth
            block_len, channels = self.get_depth_channels(model)
            for depth in range(block_len):
                channels_init = torch.zeros(channels[depth], 2, device=self.device)
                channels_init[:, 0] = float("Inf")
                channels_init[:, 1] = float("-Inf")

                accumulated_minmax.append(channels_init)

            # Get channel-wise maximum absolute values.
            if empirical:

                for data, target in loader:
                    if self.cuda:
                        data, target = data.cuda(), target.cuda()

                    channel_minmax = model.module.get_minmax(data, block_output=True, channel_flag=True)
                    
                    for l_ind, (cur_minmax, acc_minmax) in enumerate(zip(channel_minmax, accumulated_minmax)):
                        cur_min, cur_max = cur_minmax[:, 0], cur_minmax[:, 1]
                        acc_min, acc_max = acc_minmax[:, 0], acc_minmax[:, 1]

                        min_stack = torch.stack((cur_min, acc_min), dim=1)
                        max_stack = torch.stack((cur_max, acc_max), dim=1)

                        cur_acc_min, _ = torch.min(min_stack, axis=1)
                        cur_acc_max, _ = torch.max(max_stack, axis=1)

                        accumulated_minmax[l_ind] = torch.stack((cur_acc_min, cur_acc_max), dim=1)

            for data, target in loader:
                if self.cuda:
                    data, target = data.cuda(), target.cuda()
                
                # output: (block_len, mini_batch, channel, height, width)
                output = model.module.get_activation(data, target='block' if empirical else 'activation')
                
                if block_sum is None:
                    # total_frequency/sum: (block_len, 1)
                    block_sum = [torch.zeros(1, device=self.device) for i in range(block_len)]

                for ind, block_output in enumerate(output):
                    # block_output: (mini_batch, channel, height, width)
                    block_output_abs = torch.abs(block_output)

                    block_output_abs_tp = block_output_abs.transpose(0, 1)
                    # block_output_tp: (channel, mini_batch, height, width)
                    h = w = 1
                    if len(block_output_abs_tp.shape) == 4:
                        c, mb, h, w = block_output_abs_tp.shape
                    else:
                        c, mb = block_output_abs_tp.shape
                    flat_output_abs = block_output_abs_tp.contiguous().view(c, -1)
                    # flat_output: (channel, , mb*h*w)
                    
                    if empirical:
                        acti_upper, _ = torch.max(torch.abs(accumulated_minmax[ind]), dim=1, keepdim=True)

                    flat_output_nor = (flat_output_abs) / acti_upper

                    # flat_output_nor_abs: (channel, mb*h*w)
                    flat_output_nor_abs = torch.abs(flat_output_nor)

                    block_sum[ind] += (torch.sum(flat_output_nor_abs) / (total_samples * c * h * w))
            
            block_saturation = [ float(bus) for bus in block_sum ]

            # print block skewness
            if not self.args.channel_logging:
                print("Empirical saturation" if empirical else "Saturation")
                for strt in block_saturation:
                    print(format(strt, '.8f'))
                print()

        return block_saturation

    def skewness(self, model, loader):
        skewness = []
        skewness = self.calculate_skewness(model=model, loader=loader)
        save_file_name = os.path.join(self.info_save_dir, 'skewness')
        self._save_log(save_file_name, skewness, 'skewness')

    def saturation(self, model, loader, empirical=False):
        saturation = self.calculate_saturation(model=model, loader=loader, empirical=empirical)

        if empirical:
            save_file_name = os.path.join(self.info_save_dir, 'emprical_saturation')
        else:
            save_file_name = os.path.join(self.info_save_dir, 'saturation')

        self._save_log(save_file_name, saturation, 'saturation')

    def get_depth_channels(self, model):
        depth = None
        channels = []

        model.eval()

        if 'cifar' in self.args.dataset:
            data = torch.randn((2, 3, 32, 32), device=self.device)
        elif 'shapeset' in self.args.dataset:
            data = torch.randn((2, 1, 32, 32), device=self.device)
        elif 'MNIST' in self.args.dataset:
            data = torch.randn((2, 1, 28, 28), device=self.device)
        elif 'tinyImageNet' == self.args.dataset:
            data = torch.randn((2, 3, 64, 64), device=self.device)
        elif 'ImageNet' == self.args.dataset:
            data = torch.randn((2, 3, 244, 244), device=self.device)

        with torch.no_grad():
            activations = model.module.get_activation(data, target='weights')

        depth = len(activations)
        for block_activation in activations:
            channels.append(block_activation.shape[1])
        
        return depth, channels

    def accuracy_save(self, accuracy):
        save_file_name = os.path.join(self.acc_save_dir, self._get_state_id('log'))
        self._save_log(save_file_name, accuracy, 'accuracy')

    def channel_saturation_skewness_plot(self):
        """Plot the logged per-layer saturation and skewness curves to PDF files.

        Writes ``<state_id>_saturation.pdf`` and ``<state_id>_skewness.pdf``
        (names built by ``_get_state_id``) relative to the current working
        directory, one curve per entry of ``saturation_logs`` /
        ``skewness_logs``.
        """
        # Four dash patterns reused cyclically, so models with more layers
        # than patterns no longer run off the end of the list.
        style_list = ['-', '--', '-.', ':']

        def _plot_layers(layer_logs, ylabel, file_suffix, **ylim_kwargs):
            # Draws one curve per layer on the current axes, saves, then clears the axes.
            color_list = cm.rainbow(np.linspace(0, 1, len(layer_logs)))

            for l_ind, l_values in enumerate(layer_logs):
                plt.plot(
                    l_values,
                    label='Layer '+str(l_ind+1),
                    color=color_list[l_ind],
                    linewidth=0.5,
                    linestyle=style_list[l_ind % len(style_list)],
                )

            # The legend sits outside the axes, so it must be passed to savefig
            # explicitly to be included in the tight bounding box.
            lg = plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
            plt.ylim(**ylim_kwargs)
            plt.ylabel(ylabel)
            plt.xlabel('Epochs')
            plt.savefig(self._get_state_id(file_suffix), bbox_extra_artists=(lg,), bbox_inches='tight')
            plt.cla()

        _plot_layers(self.saturation_logs, 'Saturation', 'saturation.pdf', bottom=0.0, top=1.0)
        _plot_layers(self.skewness_logs, 'Skewness', 'skewness.pdf', bottom=0)

    def channel_saturation_skewness_logging_init(self, model):
        block_len, channels = self.get_depth_channels(model)

        self.saturation_logs = [ [] for _ in range(block_len) ]
        self.skewness_logs = [ [] for _ in range(block_len) ]

    def channel_saturation_skewness_logging(self, model, loader):

        skewness = self.calculate_skewness(model, loader)
        saturation = self.calculate_saturation(model, loader, empirical=self.args.empirical_saturation)

        for l_ind, l_saturation in enumerate(saturation):
            self.saturation_logs[l_ind].append(l_saturation)

        for l_ind, l_skewness in enumerate(skewness):
            self.skewness_logs[l_ind].append(l_skewness)

    def tensor_board(self, label, contents, epoch):
        self.writer.add_scalars(label, contents, epoch)

    def state_save(self, model, acc):
        """Hand ``model`` and ``acc`` to ``utils.save_state`` under this run's name.

        Args:
            model: Model forwarded to ``utils.save_state``.
            acc: Accuracy value forwarded to ``utils.save_state``.
        """
        checkpoint_name = self._get_state_id('pth.tar')
        utils.save_state(model, acc, checkpoint_name)

    def activation_distribution(self, model, loader):
        """Plot per-channel activation histograms for a few fixed layers.

        Only the first batch of ``loader`` is used. Entries 16 to 19 of the
        list returned by ``model.module.get_decompose_activation`` are plotted,
        one figure each, with one step histogram per channel in ``indexes``.
        Every figure is saved as a PDF and also shown with ``plt.show()``.

        Args:
            model: Wrapped model; ``model.module.get_decompose_activation`` is
                called on the batch.
            loader: Iterable of ``(data, target)`` batches.
        """
        for data, target in loader:
            if self.cuda:
                data, target = data.cuda(), target.cuda()
            
            with torch.no_grad():
                output = model.module.get_decompose_activation(data)
            # Plot configuration
            alpha = 0.8
            # Fixed channel indices to plot.
            # NOTE(review): assumes the inspected layers have at least 236 channels — confirm.
            indexes = [37, 235,  72]
            min_layer, max_layer = 16, 19

            for ind, values in enumerate(output):
                # Only entries min_layer..max_layer (inclusive) are plotted.
                if ind < min_layer:
                    continue
                if ind > max_layer:
                    break

                # Titles for the four plotted entries, in the order given by
                # args.operation_order. No other order value defines block_order.
                if self.args.operation_order == 'cab':
                    block_order = ['Convolution Layer', 'Tanh', 'Normalizization', 'Affine']
                elif self.args.operation_order == 'cba':
                    block_order = ['Convolution Layer', 'Normalizization', 'Affine', 'Tanh']

                # Move channels to dim 0 and keep only the selected ones.
                sampled = values.transpose(0, 1)[indexes, :]

                # Common value range over the selected channels, so all
                # histograms of a figure share the same bins.
                min_value, max_value = None, None
                for c_ind, one_channel in enumerate(sampled):
                    temp_min = float(torch.min(one_channel))
                    temp_max = float(torch.max(one_channel))

                    if min_value is None or min_value > temp_min:
                        min_value = temp_min

                    if max_value is None or max_value < temp_max:
                        max_value = temp_max

                plt.figure(figsize=(16, 12))
                plt.rc('font', size=20, weight='bold') # default font size
                plt.rc('axes', labelsize=20)   # x/y axis label font size
                plt.rc('xtick', labelsize=20)  # x-axis tick font size
                plt.rc('ytick', labelsize=20)  # y-axis tick font size
                plt.rc('legend', fontsize=20)  # legend font size
                plt.rc('figure', titlesize=50) # figure title font size

                for c_ind, one_channel in enumerate(sampled):
                    flatten_value = one_channel.view(-1)
                    
                    plt.hist(flatten_value.cpu().numpy(), bins=20, density=True, 
                    label=str(indexes[c_ind])+'th channel', alpha=alpha, histtype="step",
                    range=(min_value, max_value), linewidth=2)

                plt.title(block_order[ind-min_layer])
                # Tanh outputs are shown on the fixed interval (-1, 1).
                if block_order[ind-min_layer] == 'Tanh':
                    plt.xlim((-1, 1))

                plt.xlabel('Layer output')
                plt.ylabel('Density')
                plt.legend()

                # Save next to the checkpoint-named log directory when a
                # pretrained model was given, otherwise in this run's info directory.
                if self.args.pretrained is not None:
                    save_dir= os.path.join('log', self.args.pretrained.split('/')[-1].replace('.pth.tar', ''))
                else:
                    save_dir = self.info_save_dir

                path = os.path.join(save_dir, 'acti_histo.' +self.args.arch +'.'+ self.args.operation_order + '.'+ str(self.args.test_batch_size)+'.'+ block_order[ind-min_layer].replace(' ', '_') +'.pdf')

                plt.savefig(path, bbox_inches='tight')
                plt.show()
                plt.clf()
                plt.cla()
                plt.close()
            # Only the first batch is inspected.
            break

    def cosine_similarity(self, diff_class):
        """Placeholder for a cosine-similarity analysis; not implemented.

        TODO: implement. The abandoned draft averaged the pairwise cosine
        similarity of flattened per-block activations, either over all sample
        pairs of one batch or, for ``diff_class``, over all pairs drawn from
        batches of two different classes. It sat after the early return, so it
        could never run, and it referenced names that are not defined in this
        scope (``args``, ``nn``, ``output``, ``model``).

        Args:
            diff_class: Currently unused.

        Returns:
            Always ``None``, after printing ``"Not implemented."``.
        """
        print("Not implemented.")
        return None

    def terminate_logging(self):
        self.writer.close()

    def feature_visualization(self, model, loader):
        """Save a grid image with one feature map per block for the first batch.

        For the first batch of ``loader`` the outputs of
        ``model.module.get_activation(data, target='block')`` are walked in
        order; for each one the absolute value of channel 0 of a single sample
        is collected, stopping at the first output that is not 3-D after
        selecting the sample. The maps are written to ``feature_maps.jpg`` in
        the current working directory.

        Args:
            model: Wrapped model; ``model.module.get_activation`` is called.
            loader: Iterable of ``(data, target)`` batches.
        """
        processed = []
        with torch.no_grad():
            for data, _ in loader:
                data = data.to(self.device)

                outputs = model.module.get_activation(data, target='block')

                for feature_map in outputs:
                    # Use the third sample of the batch when there is one,
                    # otherwise the last one. Indexing [2] unconditionally
                    # failed for batches of two, and a batch of one stayed 4-D
                    # and was skipped entirely.
                    sample_index = min(2, feature_map.shape[0] - 1)
                    feature_map = feature_map[sample_index]

                    # Stop at the first output that is not (channel, h, w).
                    if len(feature_map.shape) != 3:
                        break

                    gray_scale = torch.abs(feature_map[0, :])
                    processed.append(gray_scale.detach().cpu().numpy())

                # Keep the 4x4 layout, adding rows only when more than 16 maps
                # were collected (a fixed 4x4 grid rejects a 17th subplot).
                rows = max(4, -(-len(processed) // 4))
                fig = plt.figure(figsize=(30, 50))

                for i, image in enumerate(processed):
                    a = fig.add_subplot(rows, 4, i+1)
                    a.imshow(image)
                    a.axis("off")

                fig.savefig('feature_maps.jpg', bbox_inches='tight')
                # Release the figure; pyplot keeps figures alive until closed.
                plt.close(fig)

                # Only the first batch is visualised.
                break

